将“softmax+交叉熵”推广到多标签分类问题
©PaperWeekly 原创 · 作者|苏剑林
单位|追一科技
▲ 类别不平衡
这时候,一个很自然的困惑就是:为什么“n选k”要比“n选1”多做那么多工作?
笔者认为这是很不科学的事情,毕竟直觉上 n 选 k 应该只是 n 选 1 自然延伸,所以不应该要比 n 要多做那么多事情,就算 n 选 k 要复杂一些,难度也应该是慢慢过渡的,但如果变成多个二分类的话,n 选 1 反而是最难的,因为这时候类别最不均衡。
而从形式上来看,单标签分类比多标签分类要容易,就是因为单标签有 “softmax + 交叉熵”可以用,它不会存在类别不平衡的问题,而多标签分类中的 “sigmoid + 交叉熵”就存在不平衡的问题。
为了考虑这个推广,笔者进行了多次尝试,也否定了很多结果,最后确定了一个相对来说比较优雅的方案:构建组合形式的 softmax 来作为单标签 softmax 的推广。
首先,我们考虑 k 是一个固定常数的情景,这意味着预测的时候,我们直接输出得分最高的 k 个类别即可。那训练的时候呢?作为 softmax 的自然推广,我们可以考虑用下式作为 loss:
看上去“众里寻她千百度”终究是有了结果:理论有了,实现也不困难,接下来似乎就应该做实验看效果了吧?效果好的话,甚至可以考虑发 paper 了吧?看似一片光明前景呢!然而~
幸运或者不幸,在验证了它的有效性的同时,笔者请教了一些前辈大神,在他们的提示下翻看了之前没细看的 Circle Loss [5] ,看到了它里边统一的 loss 形式(原论文的公式 (1)),然后意识到了这个统一形式蕴含了一个更简明的推广方案。
让我们换一种形式看单标签分类的交叉熵 (1):
def multilabel_categorical_crossentropy(y_true, y_pred):
"""多标签分类的交叉熵
说明:y_true和y_pred的shape一致,y_true的元素非0即1,
1表示对应的类为目标类,0表示对应的类为非目标类。
"""
y_pred = (1 - 2 * y_true) * y_pred
y_pred_neg = y_pred - y_true * 1e12
y_pred_pos = y_pred - (1 - y_true) * 1e12
zeros = K.zeros_like(y_pred[..., :1])
y_pred_neg = K.concatenate([y_pred_neg, zeros], axis=-1)
y_pred_pos = K.concatenate([y_pred_pos, zeros], axis=-1)
neg_loss = K.logsumexp(y_pred_neg, axis=-1)
pos_loss = K.logsumexp(y_pred_pos, axis=-1)
return neg_loss + pos_loss
所以,结论就是
所以,最终结论就是式 (11),它就是本文要寻求的多标签分类问题的统一 loss,欢迎大家测试并报告效果。笔者也实验过几个多标签分类任务,均能媲美精调权重下的二分类方案。
要提示的是,除了标准的多标签分类问题外,还有一些常见的任务形式也可以认为是多标签分类,比如基于 0/1 标注的序列标注,典型的例子是笔者的“半指针-半标注”标注设计。
因此,从这个角度看,能被视为多标签分类来测试式 (11) 的任务就有很多了,笔者也确实在之前的三元组抽取例子 task_relation_extraction.py [6] 中尝试了 (11),最终能取得跟这里 [7] 一致的效果。
当然,最后还是要说明一下,虽然理论上式 (11) 作为多标签分类的损失函数能自动地解决很多问题,但终究是不存在绝对完美、保证有提升的方案。
参考链接
[1] https://kexue.fm/archives/3290
[2] https://kexue.fm/archives/6620
[3] https://kexue.fm/archives/4733
[4] https://en.wikipedia.org/wiki/Newton's_identities
[5] https://arxiv.org/abs/2002.10857
[6] https://github.com/bojone/bert4keras/blob/master/examples/task_relation_extraction.py
[7] https://kexue.fm/archives/7161#类别失衡
点击以下标题查看更多往期内容:
当深度学习遇上量化交易——因子挖掘篇 BERT在小米NLP业务中的实战探索 Open Images冠军方案:商汤TSD目标检测算法解读 如何理解用户评论中的细粒度情感? EAE:自编码器 + BN + 最大熵 = 生成模型 最新综述 | 强化学习中从仿真器到现实环境的迁移
#投 稿 通 道#
让你的论文被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得或技术干货。我们的目的只有一个,让知识真正流动起来。
📝 来稿标准:
• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向)
• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接
• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志
📬 投稿邮箱:
• 投稿邮箱:hr@paperweekly.site
• 所有文章配图,请单独在附件中发送
• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
关于PaperWeekly
PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。